# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Tuple, Union

import torch
import torch.nn as nn
from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule
from mmdet.models.backbones.csp_darknet import CSPLayer, Focus
from mmdet.utils import ConfigType, OptMultiConfig
from ..utils import make_divisible, make_round
from mmengine.model import constant_init, kaiming_init
from ..layers import ADown, RepNCSPELAN4, SPPELAN, Silence

from mmyolo.registry import MODELS
from ..layers import CSPLayerWithTwoConv, SPPFBottleneck, SPPELANBottleneck
from ..utils import make_divisible, make_round
from .base_backbone import BaseBackbone

# from .gelean import SPPELAN, ConvWithCBAM, ConvWithEfficientViT

## csp_darknet.py


@MODELS.register_module()
class YOLOv5CSPDarknet(BaseBackbone):
    """CSP-Darknet backbone used in YOLOv5.

    Args:
        arch (str): Architecture of CSP-Darknet, from {P5, P6}.
            Defaults to P5.
        plugins (list[dict]): List of plugins for stages, each dict contains:
            - cfg (dict, required): Cfg dict to build plugin.
            - stages (tuple[bool], optional): Stages to apply plugin, length
              should be same as 'num_stages'.
        deepen_factor (float): Depth multiplier, multiply number of
            blocks in CSP layer by this amount. Defaults to 1.0.
        widen_factor (float): Width multiplier, multiply number of
            channels in each layer by this amount. Defaults to 1.0.
        input_channels (int): Number of input image channels. Defaults to: 3.
        out_indices (Tuple[int]): Output from which stages.
            Defaults to (2, 3, 4).
        frozen_stages (int): Stages to be frozen (stop grad and set eval
            mode). -1 means not freezing any parameters. Defaults to -1.
        norm_cfg (dict): Dictionary to construct and config norm layer.
            Defaults to dict(type='BN', momentum=0.03, eps=0.001).
        act_cfg (dict): Config dict for activation layer.
            Defaults to dict(type='SiLU', inplace=True).
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only. Defaults to False.
        init_cfg (Union[dict,list[dict]], optional): Initialization config
            dict. Defaults to None.

    Example:
        >>> from mmyolo.models import YOLOv5CSPDarknet
        >>> import torch
        >>> model = YOLOv5CSPDarknet()
        >>> model.eval()
        >>> inputs = torch.rand(1, 3, 416, 416)
        >>> level_outputs = model(inputs)
        >>> for level_out in level_outputs:
        ...     print(tuple(level_out.shape))
        ...
        (1, 256, 52, 52)
        (1, 512, 26, 26)
        (1, 1024, 13, 13)
    """
    # From left to right:
    # in_channels, out_channels, num_blocks, add_identity, use_spp
    arch_settings = {
        'P5': [[64, 128, 3, True, False], [128, 256, 6, True, False],
               [256, 512, 9, True, False], [512, 1024, 3, True, True]],
        'P6': [[64, 128, 3, True, False], [128, 256, 6, True, False],
               [256, 512, 9, True, False], [512, 768, 3, True, False],
               [768, 1024, 3, True, True]]
    }

    def __init__(self,
                 arch: str = 'P5',
                 plugins: Union[dict, List[dict]] = None,
                 deepen_factor: float = 1.0,
                 widen_factor: float = 1.0,
                 input_channels: int = 3,
                 out_indices: Tuple[int] = (2, 3, 4),
                 frozen_stages: int = -1,
                 norm_cfg: ConfigType = dict(
                     type='BN', momentum=0.03, eps=0.001),
                 act_cfg: ConfigType = dict(type='SiLU', inplace=True),
                 norm_eval: bool = False,
                 init_cfg: OptMultiConfig = None):
        super().__init__(
            self.arch_settings[arch],
            deepen_factor,
            widen_factor,
            input_channels=input_channels,
            out_indices=out_indices,
            plugins=plugins,
            frozen_stages=frozen_stages,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg,
            norm_eval=norm_eval,
            init_cfg=init_cfg)

    def build_stem_layer(self) -> nn.Module:
        """Build the stem: a 6x6, stride-2 ConvModule.

        Returns:
            nn.Module: Maps ``input_channels`` to the (width-scaled) input
            channels of the first stage.
        """
        return ConvModule(
            self.input_channels,
            make_divisible(self.arch_setting[0][0], self.widen_factor),
            kernel_size=6,
            stride=2,
            padding=2,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)

    def build_stage_layer(self, stage_idx: int, setting: list) -> list:
        """Build a stage layer.

        Args:
            stage_idx (int): The index of a stage layer.
            setting (list): The architecture setting of a stage layer,
                ``[in_channels, out_channels, num_blocks, add_identity,
                use_spp]``.

        Returns:
            list: The stage's modules in execution order: a stride-2
            ConvModule, a CSPLayer and, when ``use_spp`` is set, an
            SPPFBottleneck.
        """
        in_channels, out_channels, num_blocks, add_identity, use_spp = setting

        in_channels = make_divisible(in_channels, self.widen_factor)
        out_channels = make_divisible(out_channels, self.widen_factor)
        num_blocks = make_round(num_blocks, self.deepen_factor)
        stage = []
        conv_layer = ConvModule(
            in_channels,
            out_channels,
            kernel_size=3,
            stride=2,
            padding=1,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        stage.append(conv_layer)
        csp_layer = CSPLayer(
            out_channels,
            out_channels,
            num_blocks=num_blocks,
            add_identity=add_identity,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        stage.append(csp_layer)
        if use_spp:
            spp = SPPFBottleneck(
                out_channels,
                out_channels,
                kernel_sizes=5,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg)
            stage.append(spp)
        return stage

    def init_weights(self):
        """Initialize the parameters.

        With no ``init_cfg``, every ``nn.Conv2d`` is reset to PyTorch's
        default initialization; otherwise the base-class initialization
        driven by ``init_cfg`` is used.
        """
        if self.init_cfg is None:
            for m in self.modules():
                if isinstance(m, torch.nn.Conv2d):
                    # In order to be consistent with the source code,
                    # reset the Conv2d initialization parameters
                    m.reset_parameters()
        else:
            super().init_weights()


@MODELS.register_module()
class YOLOv8CSPDarknet(BaseBackbone):
    """CSP-Darknet style backbone for YOLOv8 with GELAN-style stages.

    Each stage first downsamples by 2, using either a stride-2
    ``ConvModule`` or ``ADown`` (selected per stage by ``use_adown``). It
    then applies ``RepNCSPELAN4``. A stage with ``use_spp`` also appends
    ``SPPELAN``.

    Args:
        arch (str): Architecture of CSP-Darknet, from {P5}.
            Defaults to P5.
        last_stage_out_channels (int): Final layer output channel.
            Defaults to 1024.
        plugins (list[dict]): List of plugins for stages, each dict contains:
            - cfg (dict, required): Cfg dict to build plugin.
            - stages (tuple[bool], optional): Stages to apply plugin, length
              should be same as 'num_stages'.
        deepen_factor (float): Depth multiplier. Forwarded to the base
            class; the current stage builder does not use ``num_blocks``,
            so it does not affect the stage modules. Defaults to 1.0.
        widen_factor (float): Width multiplier, multiply number of
            channels in each layer by this amount. Defaults to 1.0.
        input_channels (int): Number of input image channels. Defaults to: 3.
        out_indices (Tuple[int]): Output from which stages.
            Defaults to (2, 3, 4).
        frozen_stages (int): Stages to be frozen (stop grad and set eval
            mode). -1 means not freezing any parameters. Defaults to -1.
        norm_cfg (dict): Dictionary to construct and config norm layer.
            Only used by the stem and the stride-2 ConvModule down layers.
            Defaults to dict(type='BN', momentum=0.03, eps=0.001).
        act_cfg (dict): Config dict for activation layer. Only used by the
            stem and the stride-2 ConvModule down layers.
            Defaults to dict(type='SiLU', inplace=True).
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only. Defaults to False.
        init_cfg (Union[dict,list[dict]], optional): Initialization config
            dict. Defaults to None.

    Example:
        >>> from mmyolo.models import YOLOv8CSPDarknet
        >>> import torch
        >>> model = YOLOv8CSPDarknet()
        >>> model.eval()
        >>> inputs = torch.rand(1, 3, 416, 416)
        >>> level_outputs = model(inputs)
        >>> for level_out in level_outputs:
        ...     print(tuple(level_out.shape))
        ...
        (1, 256, 52, 52)
        (1, 512, 26, 26)
        (1, 1024, 13, 13)
    """
    # Environment marker printed once when the class body executes (i.e. on
    # module import). It sits after the docstring so ``__doc__`` is kept.
    print("this YOLOv8CSPDarknet is from the ENV MARK-Y2")

    # From left to right:
    # in_channels, out_channels, num_blocks, use_adown, use_spp
    # The final stage's out_channels (0 here) is a placeholder that is
    # replaced by ``last_stage_out_channels`` in ``__init__``.
    arch_settings = {
        'P5': [[64, 128, 3, False, False], [128, 256, 6, False, False],
               [256, 512, 6, True, False], [512, 0, 3, True, True]],
    }

    # Hidden width of SPPELAN at widen_factor == 1.0; scaled by
    # widen_factor like every other channel count.
    spp_mid_channels = 768 // 2

    def __init__(self,
                 arch: str = 'P5',
                 last_stage_out_channels: int = 1024,
                 plugins: Union[dict, List[dict]] = None,
                 deepen_factor: float = 1.0,
                 widen_factor: float = 1.0,
                 input_channels: int = 3,
                 out_indices: Tuple[int] = (2, 3, 4),
                 frozen_stages: int = -1,
                 norm_cfg: ConfigType = dict(
                     type='BN', momentum=0.03, eps=0.001),
                 act_cfg: ConfigType = dict(type='SiLU', inplace=True),
                 norm_eval: bool = False,
                 init_cfg: OptMultiConfig = None):
        # Work on a per-instance copy so filling in the last stage's width
        # does not mutate the class-level table shared by all instances.
        arch_setting = [list(stage) for stage in self.arch_settings[arch]]
        arch_setting[-1][1] = last_stage_out_channels
        super().__init__(
            arch_setting,
            deepen_factor,
            widen_factor,
            input_channels=input_channels,
            out_indices=out_indices,
            plugins=plugins,
            frozen_stages=frozen_stages,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg,
            norm_eval=norm_eval,
            init_cfg=init_cfg)

    def build_stem_layer(self) -> nn.Module:
        """Build the stem: a 3x3, stride-2 ConvModule.

        Returns:
            nn.Module: Maps ``input_channels`` to the (width-scaled) input
            channels of the first stage.
        """
        return ConvModule(
            self.input_channels,
            make_divisible(self.arch_setting[0][0], self.widen_factor),
            kernel_size=3,
            stride=2,
            padding=1,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)

    def build_stage_layer(self, stage_idx: int, setting: list) -> list:
        """Build a stage layer.

        Args:
            stage_idx (int): The index of a stage layer (unused here).
            setting (list): The architecture setting of a stage layer,
                ``[in_channels, out_channels, num_blocks, use_adown,
                use_spp]``. ``num_blocks`` is currently ignored:
                RepNCSPELAN4 is called without a block-count argument.

        Returns:
            list: The stage's modules in execution order: a stride-2 down
            layer (``ADown`` if ``use_adown`` else ``ConvModule``),
            ``RepNCSPELAN4`` and, when ``use_spp`` is set, ``SPPELAN``.
        """
        in_channels, out_channels, _num_blocks, use_adown, use_spp = setting

        # Scale widths exactly once, the same way the stem does, so each
        # stage's input matches the previous stage's output for any
        # widen_factor.
        in_channels = make_divisible(in_channels, self.widen_factor)
        out_channels = make_divisible(out_channels, self.widen_factor)
        # RepNCSPELAN4 hidden widths: c3 = c2 / 2 and c4 = c3 / 2.
        mid_channels = out_channels // 2
        rep_channels = mid_channels // 2

        if use_adown:
            down_layer = ADown(in_channels, out_channels)
        else:
            down_layer = ConvModule(
                in_channels,
                out_channels,
                kernel_size=3,
                stride=2,
                padding=1,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg)

        # The down layer has already produced ``out_channels`` channels, so
        # that is RepNCSPELAN4's input width (c1), not ``in_channels``.
        rep_layer = RepNCSPELAN4(
            out_channels,
            out_channels,
            mid_channels,
            rep_channels)

        stage = [down_layer, rep_layer]
        if use_spp:
            stage.append(
                SPPELAN(
                    out_channels,
                    out_channels,
                    make_divisible(self.spp_mid_channels,
                                   self.widen_factor)))
        return stage

    def init_weights(self):
        """Initialize the parameters.

        With no ``init_cfg``, every ``nn.Conv2d`` is reset to PyTorch's
        default initialization; otherwise the base-class initialization
        driven by ``init_cfg`` is used.
        """
        if self.init_cfg is None:
            for m in self.modules():
                if isinstance(m, torch.nn.Conv2d):
                    # In order to be consistent with the source code,
                    # reset the Conv2d initialization parameters
                    m.reset_parameters()
        else:
            super().init_weights()


# class YOLOv8CSPDarknet(BaseBackbone):
#     print(" THIS IS IN THE BASE YOLOV8")
#     """CSP-Darknet backbone used in YOLOv8.

#     Args:
#         arch (str): Architecture of CSP-Darknet, from {P5}.
#             Defaults to P5.
#         last_stage_out_channels (int): Final layer output channel.
#             Defaults to 1024.
#         plugins (list[dict]): List of plugins for stages, each dict contains:
#             - cfg (dict, required): Cfg dict to build plugin.
#             - stages (tuple[bool], optional): Stages to apply plugin, length
#               should be same as 'num_stages'.
#         deepen_factor (float): Depth multiplier, multiply number of
#             blocks in CSP layer by this amount. Defaults to 1.0.
#         widen_factor (float): Width multiplier, multiply number of
#             channels in each layer by this amount. Defaults to 1.0.
#         input_channels (int): Number of input image channels. Defaults to: 3.
#         out_indices (Tuple[int]): Output from which stages.
#             Defaults to (2, 3, 4).
#         frozen_stages (int): Stages to be frozen (stop grad and set eval
#             mode). -1 means not freezing any parameters. Defaults to -1.
#         norm_cfg (dict): Dictionary to construct and config norm layer.
#             Defaults to dict(type='BN', requires_grad=True).
#         act_cfg (dict): Config dict for activation layer.
#             Defaults to dict(type='SiLU', inplace=True).
#         norm_eval (bool): Whether to set norm layers to eval mode, namely,
#             freeze running stats (mean and var). Note: Effect on Batch Norm
#             and its variants only. Defaults to False.
#         init_cfg (Union[dict,list[dict]], optional): Initialization config
#             dict. Defaults to None.

#     Example:
#         >>> from mmyolo.models import YOLOv8CSPDarknet
#         >>> import torch
#         >>> model = YOLOv8CSPDarknet()
#         >>> model.eval()
#         >>> inputs = torch.rand(1, 3, 416, 416)
#         >>> level_outputs = model(inputs)
#         >>> for level_out in level_outputs:
#         ...     print(tuple(level_out.shape))
#         ...
#         (1, 256, 52, 52)
#         (1, 512, 26, 26)
#         (1, 1024, 13, 13)
#     """
#     # From left to right:
#     # in_channels, out_channels, num_blocks, add_identity, use_spp
#     # the final out_channels will be set according to the param.
#     arch_settings = {
#         'P5': [[64, 128, 3, True, False], [128, 256, 6, True, False],
#                [256, 512, 6, True, False], [512, None, 3, True, True]],
#     }

#     def __init__(self,
#                  arch: str = 'P5',
#                  last_stage_out_channels: int = 1024,
#                  plugins: Union[dict, List[dict]] = None,
#                  deepen_factor: float = 1.0,
#                  widen_factor: float = 1.0,
#                  input_channels: int = 3,
#                  out_indices: Tuple[int] = (2, 3, 4),
#                  frozen_stages: int = -1,
#                  norm_cfg: ConfigType = dict(
#                      type='BN', momentum=0.03, eps=0.001),
#                  act_cfg: ConfigType = dict(type='SiLU', inplace=True),
#                  norm_eval: bool = False,
#                  init_cfg: OptMultiConfig = None):
#         self.arch_settings[arch][-1][1] = last_stage_out_channels
#         super().__init__(
#             self.arch_settings[arch],
#             deepen_factor,
#             widen_factor,
#             input_channels=input_channels,
#             out_indices=out_indices,
#             plugins=plugins,
#             frozen_stages=frozen_stages,
#             norm_cfg=norm_cfg,
#             act_cfg=act_cfg,
#             norm_eval=norm_eval,
#             init_cfg=init_cfg)

#     def build_stem_layer(self) -> nn.Module:
#         """Build a stem layer."""
#         return ConvModule(
#             self.input_channels,
#             make_divisible(self.arch_setting[0][0], self.widen_factor),
#             kernel_size=3,
#             stride=2,
#             padding=1,
#             norm_cfg=self.norm_cfg,
#             act_cfg=self.act_cfg)

#     def build_stage_layer(self, stage_idx: int, setting: list) -> list:
#         """Build a stage layer.

#         Args:
#             stage_idx (int): The index of a stage layer.
#             setting (list): The architecture setting of a stage layer.
#         """
#         in_channels, out_channels, num_blocks, add_identity, use_spp = setting

#         in_channels = make_divisible(in_channels, self.widen_factor)
#         out_channels = make_divisible(out_channels, self.widen_factor)
#         num_blocks = make_round(num_blocks, self.deepen_factor)
#         stage = []
#         conv_layer = ConvModule(
#             in_channels,
#             out_channels,
#             kernel_size=3,
#             stride=2,
#             padding=1,
#             norm_cfg=self.norm_cfg,
#             act_cfg=self.act_cfg)
#         stage.append(conv_layer)
#         csp_layer = CSPLayerWithTwoConv(
#             out_channels,
#             out_channels,
#             num_blocks=num_blocks,
#             add_identity=add_identity,
#             norm_cfg=self.norm_cfg,
#             act_cfg=self.act_cfg)
#         stage.append(csp_layer)
#         if use_spp:
#             spp = SPPFBottleneck(
#                 out_channels,
#                 out_channels,
#                 kernel_sizes=5,
#                 norm_cfg=self.norm_cfg,
#                 act_cfg=self.act_cfg)
#             stage.append(spp)
#         return stage

#     def init_weights(self):
#         """Initialize the parameters."""
#         if self.init_cfg is None:
#             for m in self.modules():
#                 if isinstance(m, torch.nn.Conv2d):
#                     # In order to be consistent with the source code,
#                     # reset the Conv2d initialization parameters
#                     m.reset_parameters()
#         else:
#             super().init_weights()

@MODELS.register_module()
class YOLOXCSPDarknet(BaseBackbone):
    """CSP-Darknet backbone used in YOLOX.

    Args:
        arch (str): Architecture of CSP-Darknet, from {P5}.
            Defaults to P5.
        plugins (list[dict]): List of plugins for stages, each dict contains:

            - cfg (dict, required): Cfg dict to build plugin.
            - stages (tuple[bool], optional): Stages to apply plugin, length
              should be same as 'num_stages'.
        deepen_factor (float): Depth multiplier, multiply number of
            blocks in CSP layer by this amount. Defaults to 1.0.
        widen_factor (float): Width multiplier, multiply number of
            channels in each layer by this amount. Defaults to 1.0.
        input_channels (int): Number of input image channels. Defaults to 3.
        out_indices (Tuple[int]): Output from which stages.
            Defaults to (2, 3, 4).
        frozen_stages (int): Stages to be frozen (stop grad and set eval
            mode). -1 means not freezing any parameters. Defaults to -1.
        use_depthwise (bool): Whether to use depthwise separable convolution.
            Defaults to False.
        spp_kernal_sizes (tuple[int]): Kernel sizes of the pooling layers in
            the SPP block. Defaults to (5, 9, 13).
        norm_cfg (dict): Dictionary to construct and config norm layer.
            Defaults to dict(type='BN', momentum=0.03, eps=0.001).
        act_cfg (dict): Config dict for activation layer.
            Defaults to dict(type='SiLU', inplace=True).
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only. Defaults to False.
        init_cfg (Union[dict,list[dict]], optional): Initialization config
            dict. Defaults to None.

    Example:
        >>> from mmyolo.models import YOLOXCSPDarknet
        >>> import torch
        >>> model = YOLOXCSPDarknet()
        >>> model.eval()
        >>> inputs = torch.rand(1, 3, 416, 416)
        >>> level_outputs = model(inputs)
        >>> for level_out in level_outputs:
        ...     print(tuple(level_out.shape))
        ...
        (1, 256, 52, 52)
        (1, 512, 26, 26)
        (1, 1024, 13, 13)
    """
    # From left to right:
    # in_channels, out_channels, num_blocks, add_identity, use_spp
    arch_settings = {
        'P5': [[64, 128, 3, True, False], [128, 256, 9, True, False],
               [256, 512, 9, True, False], [512, 1024, 3, False, True]],
    }

    def __init__(self,
                 arch: str = 'P5',
                 plugins: Union[dict, List[dict]] = None,
                 deepen_factor: float = 1.0,
                 widen_factor: float = 1.0,
                 input_channels: int = 3,
                 out_indices: Tuple[int] = (2, 3, 4),
                 frozen_stages: int = -1,
                 use_depthwise: bool = False,
                 spp_kernal_sizes: Tuple[int] = (5, 9, 13),
                 norm_cfg: ConfigType = dict(
                     type='BN', momentum=0.03, eps=0.001),
                 act_cfg: ConfigType = dict(type='SiLU', inplace=True),
                 norm_eval: bool = False,
                 init_cfg: OptMultiConfig = None):
        # These are read by build_stage_layer, which the base-class
        # constructor calls, so they must be set before super().__init__.
        self.use_depthwise = use_depthwise
        self.spp_kernal_sizes = spp_kernal_sizes
        # Keyword arguments (as in the sibling backbones) so the call does
        # not depend on the base class's positional parameter order.
        super().__init__(
            self.arch_settings[arch],
            deepen_factor,
            widen_factor,
            input_channels=input_channels,
            out_indices=out_indices,
            frozen_stages=frozen_stages,
            plugins=plugins,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg,
            norm_eval=norm_eval,
            init_cfg=init_cfg)

    def build_stem_layer(self) -> nn.Module:
        """Build the Focus stem.

        Returns:
            nn.Module: Maps ``input_channels`` (previously hard-coded to 3)
            to the (width-scaled) input channels of the first stage.
        """
        return Focus(
            self.input_channels,
            make_divisible(self.arch_setting[0][0], self.widen_factor),
            kernel_size=3,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)

    def build_stage_layer(self, stage_idx: int, setting: list) -> list:
        """Build a stage layer.

        Args:
            stage_idx (int): The index of a stage layer.
            setting (list): The architecture setting of a stage layer,
                ``[in_channels, out_channels, num_blocks, add_identity,
                use_spp]``.

        Returns:
            list: The stage's modules in execution order: a stride-2 conv
            (depthwise-separable if ``use_depthwise``), an optional
            SPPFBottleneck, then a CSPLayer.
        """
        in_channels, out_channels, num_blocks, add_identity, use_spp = setting

        in_channels = make_divisible(in_channels, self.widen_factor)
        out_channels = make_divisible(out_channels, self.widen_factor)
        num_blocks = make_round(num_blocks, self.deepen_factor)
        stage = []
        conv = DepthwiseSeparableConvModule \
            if self.use_depthwise else ConvModule
        conv_layer = conv(
            in_channels,
            out_channels,
            kernel_size=3,
            stride=2,
            padding=1,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        stage.append(conv_layer)
        # Unlike YOLOv5, YOLOX places the SPP block before the CSP layer.
        if use_spp:
            spp = SPPFBottleneck(
                out_channels,
                out_channels,
                kernel_sizes=self.spp_kernal_sizes,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg)
            stage.append(spp)
        csp_layer = CSPLayer(
            out_channels,
            out_channels,
            num_blocks=num_blocks,
            add_identity=add_identity,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        stage.append(csp_layer)
        return stage
